{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "962e28eb",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# 编码器-解码器架构\n",
    ":label:`sec_encoder-decoder`\n",
    "\n",
    "正如我们在 :numref:`sec_machine_translation`中所讨论的，\n",
    "机器翻译是序列转换模型的一个核心问题，\n",
    "其输入和输出都是长度可变的序列。\n",
    "为了处理这种类型的输入和输出，\n",
    "我们可以设计一个包含两个主要组件的架构：\n",
    "第一个组件是一个*编码器*（encoder）：\n",
    "它接受一个长度可变的序列作为输入，\n",
    "并将其转换为具有固定形状的编码状态。\n",
    "第二个组件是*解码器*（decoder）：\n",
    "它将固定形状的编码状态映射到长度可变的序列。\n",
    "这被称为*编码器-解码器*（encoder-decoder）架构，\n",
    "如 :numref:`fig_encoder_decoder` 所示。\n",
    "\n",
    "![编码器-解码器架构](../img/encoder-decoder.svg)\n",
    ":label:`fig_encoder_decoder`\n",
    "\n",
    "我们以英语到法语的机器翻译为例：\n",
    "给定一个英文的输入序列：“They”“are”“watching”“.”。\n",
    "首先，这种“编码器－解码器”架构将长度可变的输入序列编码成一个“状态”，\n",
    "然后对该状态进行解码，\n",
    "一个词元接着一个词元地生成翻译后的序列作为输出：\n",
    "“Ils”“regordent”“.”。\n",
    "由于“编码器－解码器”架构是形成后续章节中不同序列转换模型的基础，\n",
    "因此本节将把这个架构转换为接口方便后面的代码实现。\n",
    "\n",
    "## (**编码器**)\n",
    "\n",
    "在编码器接口中，我们只指定长度可变的序列作为编码器的输入`X`。\n",
    "任何继承这个`Encoder`基类的模型将完成代码实现。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "17f77c60",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:05:48.406295Z",
     "iopub.status.busy": "2023-08-18T07:05:48.405469Z",
     "iopub.status.idle": "2023-08-18T07:05:49.653322Z",
     "shell.execute_reply": "2023-08-18T07:05:49.651979Z"
    },
    "origin_pos": 2,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "\n",
    "\n",
    "#@save\n",
    "class Encoder(nn.Module):\n",
    "    \"\"\"编码器-解码器架构的基本编码器接口\"\"\"\n",
    "    def __init__(self, **kwargs):\n",
    "        super(Encoder, self).__init__(**kwargs)\n",
    "\n",
    "    def forward(self, X, *args):\n",
    "        raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de7f0caf",
   "metadata": {
    "origin_pos": 5
   },
   "source": [
    "## [**解码器**]\n",
    "\n",
    "在下面的解码器接口中，我们新增一个`init_state`函数，\n",
    "用于将编码器的输出（`enc_outputs`）转换为编码后的状态。\n",
    "注意，此步骤可能需要额外的输入，例如：输入序列的有效长度，\n",
    "这在 :numref:`subsec_mt_data_loading`中进行了解释。\n",
    "为了逐个地生成长度可变的词元序列，\n",
    "解码器在每个时间步都会将输入\n",
    "（例如：在前一时间步生成的词元）和编码后的状态\n",
    "映射成当前时间步的输出词元。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5c7a6471",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:05:49.659889Z",
     "iopub.status.busy": "2023-08-18T07:05:49.659020Z",
     "iopub.status.idle": "2023-08-18T07:05:49.666360Z",
     "shell.execute_reply": "2023-08-18T07:05:49.665230Z"
    },
    "origin_pos": 7,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "#@save\n",
    "class Decoder(nn.Module):\n",
    "    \"\"\"编码器-解码器架构的基本解码器接口\"\"\"\n",
    "    def __init__(self, **kwargs):\n",
    "        super(Decoder, self).__init__(**kwargs)\n",
    "\n",
    "    def init_state(self, enc_outputs, *args):\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def forward(self, X, state):\n",
    "        raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e0548de",
   "metadata": {
    "origin_pos": 10
   },
   "source": [
    "## [**合并编码器和解码器**]\n",
    "\n",
    "总而言之，“编码器-解码器”架构包含了一个编码器和一个解码器，\n",
    "并且还拥有可选的额外的参数。\n",
    "在前向传播中，编码器的输出用于生成编码状态，\n",
    "这个状态又被解码器作为其输入的一部分。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "53fb0929",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:05:49.671685Z",
     "iopub.status.busy": "2023-08-18T07:05:49.670944Z",
     "iopub.status.idle": "2023-08-18T07:05:49.678831Z",
     "shell.execute_reply": "2023-08-18T07:05:49.677718Z"
    },
    "origin_pos": 12,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "#@save\n",
    "class EncoderDecoder(nn.Module):\n",
    "    \"\"\"编码器-解码器架构的基类\"\"\"\n",
    "    def __init__(self, encoder, decoder, **kwargs):\n",
    "        super(EncoderDecoder, self).__init__(**kwargs)\n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "\n",
    "    def forward(self, enc_X, dec_X, *args):\n",
    "        enc_outputs = self.encoder(enc_X, *args)\n",
    "        dec_state = self.decoder.init_state(enc_outputs, *args)\n",
    "        return self.decoder(dec_X, dec_state)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dce5eb8e",
   "metadata": {
    "origin_pos": 15
   },
   "source": [
    "“编码器－解码器”体系架构中的术语*状态*\n",
    "会启发人们使用具有状态的神经网络来实现该架构。\n",
    "在下一节中，我们将学习如何应用循环神经网络，\n",
    "来设计基于“编码器－解码器”架构的序列转换模型。\n",
    "\n",
    "## 小结\n",
    "\n",
    "* “编码器－解码器”架构可以将长度可变的序列作为输入和输出，因此适用于机器翻译等序列转换问题。\n",
    "* 编码器将长度可变的序列作为输入，并将其转换为具有固定形状的编码状态。\n",
    "* 解码器将具有固定形状的编码状态映射为长度可变的序列。\n",
    "\n",
    "## 练习\n",
    "\n",
    "1. 假设我们使用神经网络来实现“编码器－解码器”架构，那么编码器和解码器必须是同一类型的神经网络吗？\n",
    "1. 除了机器翻译，还有其它可以适用于”编码器－解码器“架构的应用吗？\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99846b42",
   "metadata": {
    "origin_pos": 17,
    "tab": [
     "pytorch"
    ]
   },
   "source": [
    "[Discussions](https://discuss.d2l.ai/t/2779)\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "required_libs": []
 },
 "nbformat": 4,
 "nbformat_minor": 5
}